{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Format DataFrame\n",
    "\n",
    "Be advised: this dataset (scikit-learn's Forest Covertypes) can take a while to download.\n",
    "\n",
    "This is a multi-class classification task, in which the target column, `y`, is label-encoded as integers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(581012, 55)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x_0</th>\n",
       "      <th>x_1</th>\n",
       "      <th>x_2</th>\n",
       "      <th>x_3</th>\n",
       "      <th>x_4</th>\n",
       "      <th>x_5</th>\n",
       "      <th>x_6</th>\n",
       "      <th>x_7</th>\n",
       "      <th>x_8</th>\n",
       "      <th>x_9</th>\n",
       "      <th>...</th>\n",
       "      <th>x_45</th>\n",
       "      <th>x_46</th>\n",
       "      <th>x_47</th>\n",
       "      <th>x_48</th>\n",
       "      <th>x_49</th>\n",
       "      <th>x_50</th>\n",
       "      <th>x_51</th>\n",
       "      <th>x_52</th>\n",
       "      <th>x_53</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3247.0</td>\n",
       "      <td>289.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>268.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>1624.0</td>\n",
       "      <td>186.0</td>\n",
       "      <td>238.0</td>\n",
       "      <td>193.0</td>\n",
       "      <td>2525.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3200.0</td>\n",
       "      <td>46.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>162.0</td>\n",
       "      <td>45.0</td>\n",
       "      <td>1592.0</td>\n",
       "      <td>223.0</td>\n",
       "      <td>200.0</td>\n",
       "      <td>105.0</td>\n",
       "      <td>2254.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2368.0</td>\n",
       "      <td>48.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>277.0</td>\n",
       "      <td>121.0</td>\n",
       "      <td>1260.0</td>\n",
       "      <td>224.0</td>\n",
       "      <td>196.0</td>\n",
       "      <td>99.0</td>\n",
       "      <td>1237.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2828.0</td>\n",
       "      <td>50.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>417.0</td>\n",
       "      <td>73.0</td>\n",
       "      <td>1252.0</td>\n",
       "      <td>225.0</td>\n",
       "      <td>215.0</td>\n",
       "      <td>123.0</td>\n",
       "      <td>962.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2932.0</td>\n",
       "      <td>32.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>618.0</td>\n",
       "      <td>55.0</td>\n",
       "      <td>638.0</td>\n",
       "      <td>218.0</td>\n",
       "      <td>217.0</td>\n",
       "      <td>134.0</td>\n",
       "      <td>1092.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 55 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      x_0    x_1   x_2    x_3    x_4     x_5    x_6    x_7    x_8     x_9 ...  \\\n",
       "0  3247.0  289.0  12.0  268.0   40.0  1624.0  186.0  238.0  193.0  2525.0 ...   \n",
       "1  3200.0   46.0  17.0  162.0   45.0  1592.0  223.0  200.0  105.0  2254.0 ...   \n",
       "2  2368.0   48.0  19.0  277.0  121.0  1260.0  224.0  196.0   99.0  1237.0 ...   \n",
       "3  2828.0   50.0  11.0  417.0   73.0  1252.0  225.0  215.0  123.0   962.0 ...   \n",
       "4  2932.0   32.0  11.0  618.0   55.0   638.0  218.0  217.0  134.0  1092.0 ...   \n",
       "\n",
       "   x_45  x_46  x_47  x_48  x_49  x_50  x_51  x_52  x_53  y  \n",
       "0   0.0   1.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0  1  \n",
       "1   0.0   1.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0  1  \n",
       "2   0.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0  3  \n",
       "3   0.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0  2  \n",
       "4   0.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0  1  \n",
       "\n",
       "[5 rows x 55 columns]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from sklearn.datasets import fetch_covtype\n",
    "\n",
    "# Fixed seed: the dataset's rows are shuffled identically on every run\n",
    "data = fetch_covtype(shuffle=True, random_state=32)\n",
    "\n",
    "# Name the feature columns x_0, x_1, ... and append the label-encoded target as the last column, `y`\n",
    "train_df = pd.DataFrame(data.data).add_prefix(\"x_\").assign(y=data.target)\n",
    "\n",
    "print(train_df.shape)\n",
    "train_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cross-Experiment Key:   'WQMO2i1RnEaE7cguWwpBywkh25UKTgtwR12Z0LqWIUM='\n"
     ]
    }
   ],
   "source": [
    "from hyperparameter_hunter import Environment, CVExperiment\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "env = Environment(\n",
    "    train_dataset=train_df,\n",
    "    results_path=\"HyperparameterHunterAssets\",\n",
    "    target_column=\"y\",\n",
    "    # Micro-averaged F1; for a single-label multi-class target this equals accuracy\n",
    "    metrics=dict(f1=lambda y_true, y_pred: f1_score(y_true, y_pred, average=\"micro\")),\n",
    "    cv_type=\"StratifiedKFold\",\n",
    "    # `random_state` only has an effect when `shuffle=True`; scikit-learn >= 0.24 raises a\n",
    "    # ValueError if it is set while `shuffle` is left at its default of False\n",
    "    cv_params=dict(n_splits=5, shuffle=True, random_state=32),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that HyperparameterHunter has an active `Environment`, we can do two things:\n",
    "\n",
    "# 1. Perform Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<15:06:35> Validated Environment:  'WQMO2i1RnEaE7cguWwpBywkh25UKTgtwR12Z0LqWIUM='\n",
      "<15:06:35> Initialized Experiment: 'f2096258-17fd-47b4-a384-362a43cc8cbd'\n",
      "<15:06:35> Hyperparameter Key:     'Hyx-Jo5QIqiXxDRLIQjh5_uQ2JVsViCjaGhWzzoYpy4='\n",
      "<15:06:35> \n",
      "<15:06:44> F0.0 AVG:   OOF(f1=0.83622)  |  Time Elapsed: 9.24039 s\n",
      "<15:06:54> F0.1 AVG:   OOF(f1=0.83796)  |  Time Elapsed: 9.17901 s\n",
      "<15:07:03> F0.2 AVG:   OOF(f1=0.83635)  |  Time Elapsed: 9.43469 s\n",
      "<15:07:12> F0.3 AVG:   OOF(f1=0.83682)  |  Time Elapsed: 9.32817 s\n",
      "<15:07:22> F0.4 AVG:   OOF(f1=0.83370)  |  Time Elapsed: 9.22297 s\n",
      "<15:07:22> \n",
      "<15:07:22> FINAL:    OOF(f1=0.83621)  |  Time Elapsed: 46.77216 s\n",
      "<15:07:22> \n",
      "<15:07:22> Saving results for Experiment: 'f2096258-17fd-47b4-a384-362a43cc8cbd'\n"
     ]
    }
   ],
   "source": [
    "from lightgbm import LGBMClassifier\n",
    "\n",
    "# In the UCI Covertype layout, x_0..x_9 are the 10 quantitative features and x_10..x_53 are\n",
    "# one-hot indicators (4 wilderness areas + 40 soil types), so categoricals start at x_10\n",
    "N_NUMERIC_FEATURES = 10\n",
    "FEATURE_NAMES = train_df.columns.values[:-1].tolist()  # all columns except the trailing target, \"y\"\n",
    "CATEGORICAL_FEATURES = FEATURE_NAMES[N_NUMERIC_FEATURES:]\n",
    "\n",
    "experiment = CVExperiment(\n",
    "    model_initializer=LGBMClassifier,\n",
    "    model_init_params=dict(\n",
    "        boosting_type=\"gbdt\",\n",
    "        num_leaves=31,\n",
    "        max_depth=-1,\n",
    "        subsample=0.5,\n",
    "        subsample_freq=1,  # LightGBM ignores `subsample` unless this is > 0 (default 0)\n",
    "    ),\n",
    "    model_extra_params=dict(\n",
    "        fit=dict(feature_name=FEATURE_NAMES, categorical_feature=CATEGORICAL_FEATURES),\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Hyperparameter Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validated Environment with key: \"WQMO2i1RnEaE7cguWwpBywkh25UKTgtwR12Z0LqWIUM=\"\n",
      "\u001b[31mSaved Result Files\u001b[0m\n",
      "\u001b[31m_______________________________________________________________________________________\u001b[0m\n",
      " Step |       ID |   Time |      Value |   boosting_type |   num_leaves |   subsample | \n",
      "Experiments matching cross-experiment key/algorithm: 1\n",
      "Experiments fitting in the given space: 1\n",
      "Experiments matching current guidelines: 1\n",
      "    0 | f2096258 | 00m00s | \u001b[35m   0.83621\u001b[0m | \u001b[32m           gbdt\u001b[0m | \u001b[32m          31\u001b[0m | \u001b[32m     0.5000\u001b[0m | \n",
      "\u001b[31mHyperparameter Optimization\u001b[0m\n",
      "\u001b[31m_______________________________________________________________________________________\u001b[0m\n",
      " Step |       ID |   Time |      Value |   boosting_type |   num_leaves |   subsample | \n",
      "    1 | 00708066 | 01m23s |    0.76716 |            dart |           15 |      0.4684 | \n",
      "    2 | dd0307d2 | 00m52s |    0.83191 |            gbdt |           29 |      0.5947 | \n",
      "    3 | e3b29434 | 01m20s |    0.76080 |            dart |           13 |      0.4824 | \n",
      "    4 | 655b0837 | 00m48s |    0.81328 |            gbdt |           21 |      0.5386 | \n",
      "    5 | 6f175f80 | 00m39s |    0.78835 |            gbdt |           13 |      0.4253 | \n",
      "    6 | 58f93d09 | 01m46s |    0.80076 |            dart |           30 |      0.5206 | \n",
      "    7 | 8bdd341f | 01m21s |    0.76753 |            dart |           15 |      0.3583 | \n",
      "    8 | b298abd1 | 01m57s |    0.81578 |            dart |           40 |      0.6065 | \n",
      "    9 | 73bcdf23 | 01m50s |    0.80087 |            dart |           30 |      0.6893 | \n",
      "   10 | 258a3d2d | 00m49s | \u001b[35m   0.83957\u001b[0m | \u001b[32m           gbdt\u001b[0m | \u001b[32m          33\u001b[0m | \u001b[32m     0.4972\u001b[0m | \n",
      "Optimization loop completed in 0:12:49.753418\n",
      "Best score was 0.8395695785973439 from Experiment \"258a3d2d-ad66-43be-8395-cdecbd1aea1a\"\n"
     ]
    }
   ],
   "source": [
    "from hyperparameter_hunter import RandomForestOptPro, Real, Integer, Categorical\n",
    "\n",
    "optimizer = RandomForestOptPro(iterations=10, random_state=32)\n",
    "\n",
    "# Non-search values (`max_depth`, `subsample_freq`, fit params) are guidelines: saved experiments\n",
    "# are only reused when they match them, so keep these in sync with `experiment` above\n",
    "optimizer.forge_experiment(\n",
    "    model_initializer=LGBMClassifier,\n",
    "    model_init_params=dict(\n",
    "        boosting_type=Categorical([\"gbdt\", \"dart\"]),\n",
    "        num_leaves=Integer(10, 40),\n",
    "        max_depth=-1,\n",
    "        subsample=Real(0.3, 0.7),\n",
    "        subsample_freq=1,  # required for `subsample` to have any effect\n",
    "    ),\n",
    "    model_extra_params=dict(\n",
    "        fit=dict(feature_name=FEATURE_NAMES, categorical_feature=CATEGORICAL_FEATURES),\n",
    "    ),\n",
    ")\n",
    "\n",
    "optimizer.go()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that `optimizer` recognizes that our earlier `experiment`'s hyperparameters fit inside the search space and guidelines given to `optimizer`.\n",
    "\n",
    "Then, when optimization starts, it automatically learns from `experiment`'s results — without any extra work on our part!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "varInspector": {
   "cols": {
    "lenName": 16.0,
    "lenType": 16.0,
    "lenVar": 40.0
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}